import os, sys
import cv2
import glob
import shutil
import numpy as np
from tqdm import tqdm
from imageio import imread, imsave

from src.dain_model.base_predictor import BasePredictor

DAIN_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DAIN_weight.tar'


def video2frames(video_path, outpath, **kargs):
    """Extract every frame of a video into zero-based PNG files via ffmpeg.

    Frames are written as ``outpath/<video_name>/00000000.png``,
    ``00000001.png``, ... where ``<video_name>`` is the basename of
    ``video_path`` up to the first dot.

    Args:
        video_path (str): path to the input video.
        outpath (str): directory under which the per-video frame folder
            is created.
        **kargs: extra ``key value`` pairs appended verbatim to the
            ffmpeg command line.

    Returns:
        str: the directory the frames were written to.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    import shlex

    def _dict2str(kargs):
        # Flatten extra options into " key value" pairs for the command line.
        cmd_str = ''
        for k, v in kargs.items():
            cmd_str += (' ' + str(k) + ' ' + str(v))
        return cmd_str

    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
    # NOTE: split('.')[0] (not splitext) is kept deliberately — the rest of
    # this file derives `vidname` the same way, so "a.b.mp4" maps to "a".
    vid_name = os.path.basename(video_path).split('.')[0]
    out_full_path = os.path.join(outpath, vid_name)

    os.makedirs(out_full_path, exist_ok=True)

    # Output frame filename pattern (zero-padded, 0-based).
    outformat = os.path.join(out_full_path, '%08d.png')

    # Quote both paths so names containing spaces or shell metacharacters
    # survive the os.system() shell invocation below. shlex.quote leaves
    # "%08d.png" untouched ('%' is in its safe character set).
    cmd = ffmpeg + [
        ' -i ', shlex.quote(video_path), ' -start_number ', ' 0 ',
        shlex.quote(outformat)
    ]

    cmd = ''.join(cmd) + _dict2str(kargs)

    if os.system(cmd) != 0:
        raise RuntimeError('ffmpeg process video: {} error'.format(vid_name))

    sys.stdout.flush()
    return out_full_path


def frames2video(frame_path, video_path, r, w, h):
    """Assemble the PNG frames in ``frame_path`` into a DIVX-encoded video.

    Args:
        frame_path (str): directory containing the ``*.png`` frames; they
            are written in lexicographic (i.e. frame-number) order.
        video_path (str): output video file path.
        r: frame rate; converted with ``int()`` before use.
        w (int): frame width in pixels.
        h (int): frame height in pixels.
    """
    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
    writer = cv2.VideoWriter(video_path, fourcc, int(r), (w, h))
    for frame_file in sorted(glob.glob(os.path.join(frame_path, '*.png'))):
        writer.write(cv2.imread(frame_file))
    writer.release()


class DAINPredictor(BasePredictor):
    """Video frame-interpolation predictor based on the DAIN model.

    Pipeline: extract frames with ffmpeg, run the inference model on each
    consecutive frame pair to synthesize intermediate frames, then
    interleave originals and interpolated frames and re-encode them as a
    higher-fps video.
    """

    def __init__(self,
                 output='output',
                 weight_path=None,
                 time_step=None,
                 use_gpu=True,
                 remove_duplicates=False):
        """
        Args:
            output (str): root directory; all results go under ``output/DAIN``.
            weight_path (str|None): path to the inference weights
                (consumed by ``build_inference_model`` — defined outside
                this file).
            time_step (float|None): interpolation step; ``1/time_step`` is
                the frame-rate multiplier. Must be a positive float before
                ``run()`` is called (a ``None`` would fail there).
            use_gpu (bool): accepted for interface compatibility; not read
                in this class.
            remove_duplicates (bool): if True, near-duplicate frames are
                dropped (perceptual dhash) before interpolation.
        """
        self.output_path = os.path.join(output, 'DAIN')
        self.weight_path = weight_path
        self.time_step = time_step
        # NOTE(review): name looks like a typo for "key_frame_thresh";
        # kept as-is for attribute compatibility. Unused in this file.
        self.key_frame_thread = 0
        self.remove_duplicates = remove_duplicates

        self.build_inference_model()

    def run(self, video_path):
        """Interpolate ``video_path`` to a higher frame rate.

        Args:
            video_path (str): path to the input video file.

        Returns:
            tuple(str, str): (directory of combined output frames, path of
            the interpolated ``.avi`` video), or ``None`` if the frames are
            not 3-channel.
        """
        # Clear any results of a previous run. (Replaces the original
        # os.system('rm -rf {}/*'), which was non-portable and unsafe for
        # paths containing spaces.)
        if os.path.exists(self.output_path):
            shutil.rmtree(self.output_path)

        frame_path_input = os.path.join(self.output_path, 'frames-input')
        frame_path_interpolated = os.path.join(self.output_path,
                                               'frames-interpolated')
        frame_path_combined = os.path.join(self.output_path, 'frames-combined')
        video_path_output = os.path.join(self.output_path, 'videos-output')

        for folder in (self.output_path, frame_path_input,
                       frame_path_interpolated, frame_path_combined,
                       video_path_output):
            os.makedirs(folder, exist_ok=True)

        timestep = self.time_step
        # Frames synthesized between each pair of consecutive input frames.
        num_frames = int(1.0 / timestep) - 1

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()  # BUGFIX: the capture handle was never released
        print("Old fps (frame rate): ", fps)

        times_interp = int(1.0 / timestep)
        r2 = str(int(fps) * times_interp)
        print("New fps (frame rate): ", r2)

        out_path = video2frames(video_path, frame_path_input)

        # Basename up to the first dot — must match video2frames' folder name.
        vidname = os.path.basename(video_path).split('.')[0]

        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))

        if self.remove_duplicates:
            frames = self.remove_duplicate_frames(out_path)

        img = imread(frames[0])

        int_width = img.shape[1]
        int_height = img.shape[0]
        channel = img.shape[2]
        if not channel == 3:
            # Only 3-channel (RGB) frames are supported.
            return

        # Pad each dimension up to the next multiple of 128; dimensions that
        # are already aligned still get a symmetric 32px border. (Presumably
        # DAIN's required input alignment — not verifiable from this file.)
        if int_width != ((int_width >> 7) << 7):
            int_width_pad = (((int_width >> 7) + 1) << 7)  # more than necessary
            padding_left = int((int_width_pad - int_width) / 2)
            padding_right = int_width_pad - int_width - padding_left
        else:
            padding_left = 32
            padding_right = 32

        if int_height != ((int_height >> 7) << 7):
            int_height_pad = (
                    ((int_height >> 7) + 1) << 7)  # more than necessary
            padding_top = int((int_height_pad - int_height) / 2)
            padding_bottom = int_height_pad - int_height - padding_top
        else:
            padding_top = 32
            padding_bottom = 32

        frame_num = len(frames)

        os.makedirs(os.path.join(frame_path_interpolated, vidname),
                    exist_ok=True)
        os.makedirs(os.path.join(frame_path_combined, vidname), exist_ok=True)

        for i in tqdm(range(frame_num - 1)):
            first = frames[i]
            second = frames[i + 1]
            # Numeric frame indices parsed from the %08d.png file names.
            first_index = int(first.split(os.sep)[-1].split('.')[-2])
            second_index = int(second.split(os.sep)[-1].split('.')[-2])

            img_first = imread(first)
            img_second = imread(second)

            # HWC uint8 -> CHW float32 in [0, 1].
            X0 = img_first.astype('float32').transpose((2, 0, 1)) / 255
            X1 = img_second.astype('float32').transpose((2, 0, 1)) / 255

            assert (X0.shape[1] == X1.shape[1])
            assert (X0.shape[2] == X1.shape[2])

            X0 = np.pad(X0, ((0, 0), (padding_top, padding_bottom), \
                             (padding_left, padding_right)), mode='edge')
            X1 = np.pad(X1, ((0, 0), (padding_top, padding_bottom), \
                             (padding_left, padding_right)), mode='edge')

            # Add pair and batch axes: X ends up (2, 1, C, H, W).
            X0 = np.expand_dims(X0, axis=0)
            X1 = np.expand_dims(X1, axis=0)

            X0 = np.expand_dims(X0, axis=0)
            X1 = np.expand_dims(X1, axis=0)

            X = np.concatenate((X0, X1), axis=0)

            o = self.base_forward(X)

            y_ = o[0]

            # Back to HWC in [0, 255] with the padding cropped off.
            y_ = [
                np.transpose(
                    255.0 * item.clip(
                        0, 1.0)[0, :, padding_top:padding_top + int_height,
                            padding_left:padding_left + int_width],
                    (1, 2, 0)) for item in y_
            ]
            if self.remove_duplicates:
                # Fill the (possibly multi-frame) gap left by removed
                # duplicates; files are named with their final frame numbers.
                num_frames = times_interp * (second_index - first_index) - 1
                time_offsets = [
                    kk * timestep for kk in range(1, 1 + num_frames, 1)
                ]
                start = times_interp * first_index + 1
                for item, time_offset in zip(y_, time_offsets):
                    out_dir = os.path.join(frame_path_interpolated, vidname,
                                           "{:08d}.png".format(start))
                    imsave(out_dir, np.round(item).astype(np.uint8))
                    start = start + 1

            else:
                time_offsets = [
                    kk * timestep for kk in range(1, 1 + num_frames, 1)
                ]

                count = 1
                for item, time_offset in zip(y_, time_offsets):
                    out_dir = os.path.join(
                        frame_path_interpolated, vidname,
                        "{:0>6d}_{:0>4d}.png".format(i, count))
                    count = count + 1
                    imsave(out_dir, np.round(item).astype(np.uint8))

        input_dir = os.path.join(frame_path_input, vidname)
        interpolated_dir = os.path.join(frame_path_interpolated, vidname)
        combined_dir = os.path.join(frame_path_combined, vidname)

        if self.remove_duplicates:
            self.combine_frames_with_rm(input_dir, interpolated_dir,
                                        combined_dir, times_interp)

        else:
            num_frames = int(1.0 / timestep) - 1
            self.combine_frames(input_dir, interpolated_dir, combined_dir,
                                num_frames)

        frame_pattern_combined = os.path.join(frame_path_combined, vidname)
        video_pattern_output = os.path.join(video_path_output,
                                            vidname + '.avi')
        if os.path.exists(video_pattern_output):
            os.remove(video_pattern_output)
        frames2video(frame_pattern_combined, video_pattern_output, r2,
                     int_width, int_height)

        return frame_pattern_combined, video_pattern_output

    def combine_frames(self, input, interpolated, combined, num_frames):
        """Interleave original and interpolated frames into ``combined``.

        Original frame ``i`` is copied to ``<i*(num_frames+1)>.png`` and the
        ``num_frames`` interpolated frames produced after it fill the slots
        in between.

        Args:
            input (str): directory of original frames, numbered 0..N-1.
            interpolated (str): directory of interpolated frames.
            combined (str): destination directory (must exist).
            num_frames (int): interpolated frames per original pair.
        """
        frames1 = sorted(glob.glob(os.path.join(input, '*.png')))
        frames2 = sorted(glob.glob(os.path.join(interpolated, '*.png')))
        num1 = len(frames1)

        for i in range(num1):
            src = frames1[i]
            imgname = int(src.split(os.sep)[-1].split('.')[-2])
            # Input frames must be contiguously numbered from 0.
            assert i == imgname
            dst = os.path.join(combined,
                               '{:08d}.png'.format(i * (num_frames + 1)))
            shutil.copy2(src, dst)
            if i < num1 - 1:
                # Best-effort: the last pairs may have produced fewer
                # interpolated frames than expected, so an IndexError here
                # is reported but not fatal.
                try:
                    for k in range(num_frames):
                        src = frames2[i * num_frames + k]
                        dst = os.path.join(
                            combined,
                            '{:08d}.png'.format(i * (num_frames + 1) + k + 1))
                        shutil.copy2(src, dst)
                except Exception as e:
                    print(e)

    def combine_frames_with_rm(self, input, interpolated, combined,
                               times_interp):
        """Merge frames into ``combined`` when duplicates were removed.

        A surviving original frame with index ``n`` is renumbered to
        ``times_interp * n``; interpolated frames already carry their final
        numbers (see ``run``) and are copied under their existing names.
        """
        frames1 = sorted(glob.glob(os.path.join(input, '*.png')))
        frames2 = sorted(glob.glob(os.path.join(interpolated, '*.png')))

        for src in frames1:
            index = int(src.split(os.sep)[-1].split('.')[-2])
            dst = os.path.join(combined,
                               '{:08d}.png'.format(times_interp * index))
            shutil.copy2(src, dst)

        for src in frames2:
            imgname = src.split(os.sep)[-1]
            shutil.copy2(src, os.path.join(combined, imgname))

    def remove_duplicate_frames(self, paths):
        """Delete near-duplicate frames in ``paths``; return the survivors.

        Frames are grouped by a perceptual difference hash (dhash). Within a
        run of visually identical frames only evenly spaced representatives
        are kept, so the interpolation gap stays bridgeable.

        Args:
            paths (str): directory containing ``%08d.png`` frames.

        Returns:
            list[str]: sorted paths of the frames still on disk.
        """

        def dhash(image, hash_size=8):
            # Difference hash: downscale to (hash_size+1, hash_size) grayscale
            # and hash the sign of the horizontal intensity gradients.
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            resized = cv2.resize(gray, (hash_size + 1, hash_size))
            diff = resized[:, 1:] > resized[:, :-1]
            return sum([2 ** i for (i, v) in enumerate(diff.flatten()) if v])

        hashes = {}
        max_interp = 9  # largest duplicate run the interpolator should bridge
        image_paths = sorted(glob.glob(os.path.join(paths, '*.png')))
        for image_path in image_paths:
            image = cv2.imread(image_path)
            h = dhash(image)
            p = hashes.get(h, [])
            p.append(image_path)
            hashes[h] = p

        for (h, hashed_paths) in hashes.items():
            if len(hashed_paths) > 1:
                first_index = int(hashed_paths[0].split(
                    os.sep)[-1].split('.')[-2])
                last_index = int(hashed_paths[-1].split(
                    os.sep)[-1].split('.')[-2]) + 1
                gap = 2 * (last_index - first_index) - 1
                if gap > 2 * max_interp:
                    # Very long run: keep the first frame plus two interior
                    # anchors, delete everything else.
                    cut1 = len(hashed_paths) // 3
                    cut2 = cut1 * 2
                    for p in hashed_paths[1:cut1 - 1]:
                        os.remove(p)
                    for p in hashed_paths[cut1 + 1:cut2]:
                        os.remove(p)
                    for p in hashed_paths[cut2 + 1:]:
                        os.remove(p)
                elif gap > max_interp:
                    # BUGFIX: this was a plain `if`, so after the branch above
                    # had already deleted files this branch re-deleted
                    # overlapping slices and raised FileNotFoundError.
                    # Medium run: keep the first frame and a middle anchor.
                    mid = len(hashed_paths) // 2
                    for p in hashed_paths[1:mid - 1]:
                        os.remove(p)
                    for p in hashed_paths[mid + 1:]:
                        os.remove(p)
                else:
                    # Short run: keep only the first frame.
                    for p in hashed_paths[1:]:
                        os.remove(p)

        frames = sorted(glob.glob(os.path.join(paths, '*.png')))
        return frames
